llama : refactor llama_kv_cache, llama_context and llm_build_context #11213
base: master
Conversation
I am thinking about the following API change for this PR:

```cpp
// API on `master`
DEPRECATED(LLAMA_API void      llama_kv_cache_clear      (ctx));
DEPRECATED(LLAMA_API bool      llama_kv_cache_seq_rm     (ctx));
DEPRECATED(LLAMA_API void      llama_kv_cache_seq_cp     (ctx));
DEPRECATED(LLAMA_API void      llama_kv_cache_seq_keep   (ctx));
DEPRECATED(LLAMA_API void      llama_kv_cache_seq_add    (ctx));
DEPRECATED(LLAMA_API void      llama_kv_cache_seq_div    (ctx));
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(ctx));
DEPRECATED(LLAMA_API void      llama_kv_cache_defrag     (ctx));
DEPRECATED(LLAMA_API bool      llama_kv_cache_can_shift  (ctx));
DEPRECATED(LLAMA_API void      llama_kv_cache_update     (ctx));

// works with `ctx.kv_self` - backwards compatible with `master`
LLAMA_API void      llama_kv_self_clear      (ctx);
LLAMA_API bool      llama_kv_self_seq_rm     (ctx);
LLAMA_API void      llama_kv_self_seq_cp     (ctx);
LLAMA_API void      llama_kv_self_seq_keep   (ctx);
LLAMA_API void      llama_kv_self_seq_add    (ctx);
LLAMA_API void      llama_kv_self_seq_div    (ctx);
LLAMA_API llama_pos llama_kv_self_seq_pos_max(ctx);
LLAMA_API void      llama_kv_self_defrag     (ctx);
LLAMA_API bool      llama_kv_self_can_shift  (ctx);
LLAMA_API void      llama_kv_self_update     (ctx);

// TODO: llama_kv_cache API
//       can be implemented in a later PR

// new API to access the KV cache instance
struct llama_kv_cache;

LLAMA_API struct llama_kv_cache * llama_get_kv_self(ctx);
LLAMA_API void                    llama_set_kv_self(ctx, kv);

// allow to clone, free, save, load the kv cache
```
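For illustration, a minimal migration sketch, assuming the proposed `llama_kv_self_...` names keep the same argument lists as the existing `llama_kv_cache_...` functions (the proposal above abbreviates the signatures to `(ctx)`):

```cpp
#include "llama.h"

// sketch only: clear the cache for all sequences via the proposed API;
// the deprecated calls are shown in comments for comparison (arguments unchanged)
static void reset_cache(llama_context * ctx) {
    // before (deprecated):
    //   llama_kv_cache_clear (ctx);
    //   llama_kv_cache_seq_rm(ctx, /*seq_id=*/-1, /*p0=*/-1, /*p1=*/-1);

    // after:
    llama_kv_self_clear (ctx);
    llama_kv_self_seq_rm(ctx, /*seq_id=*/-1, /*p0=*/-1, /*p1=*/-1);
}
```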
src/llama.cpp (outdated)
```diff
 void llama_kv_self_update(llama_context * ctx) {
-    llama_kv_self_update_impl(*ctx);
+    const bool need_reserve = ctx->kv_self_update();
+
+    // reserve a worst case graph again
+    if (need_reserve) {
+        // TODO: extract to a function
+        const auto & cparams = ctx->cparams;
+        const auto & model   = ctx->model;
+
+        // build worst-case graph
+        uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+        ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
+
+        // initialize scheduler with the worst-case graph
+        ggml_backend_sched_reset(ctx->sched.get());
+        if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        }
+    }
 }
```
@slaren If we have a separate scheduler for the `kv_self` updates (such as K-shift and defrag), would this worst-case reservation be necessary?
No, but that would increase memory usage.
```diff
@@ -460,8 +461,9 @@ extern "C" {

     DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");

-    LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
+    LLAMA_API enum llama_pooling_type    llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx); // TODO: remove const?
```
`llama_model` should always be immutable, otherwise it would be hard to guarantee thread-safety when it is used in multiple contexts. So returning a `const` pointer should be correct here.
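To illustrate the point, a minimal hypothetical sketch using the public `llama.h` API (the model path and the worker body are placeholders; error checking omitted): because the model handle exposed through the context is `const`, several contexts can share the same model across threads without mutating the weights.

```cpp
#include "llama.h"
#include <thread>

static void worker(llama_context * ctx) {
    const llama_model * model = llama_get_model(ctx); // read-only view of the shared model
    (void) model; // ... build batches and call llama_decode(ctx, ...) here ...
}

int main() {
    llama_backend_init();

    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());

    // two independent contexts sharing the same immutable model
    llama_context * ctx_a = llama_init_from_model(model, llama_context_default_params());
    llama_context * ctx_b = llama_init_from_model(model, llama_context_default_params());

    std::thread ta(worker, ctx_a);
    std::thread tb(worker, ctx_b);
    ta.join();
    tb.join();

    llama_free(ctx_a);
    llama_free(ctx_b);
    llama_model_free(model);

    llama_backend_free();
}
```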
@slaren Coming back to your comment from earlier: #11110 (review)

In the OP I have outlined a possible approach to make the implementation more abstract. I have focused primarily on the abstraction of the KV cache and the llama context. If I understand correctly your suggestion, the idea is to have the compute graph build functions for each of the arches (e.g. …).

I haven't really thought enough about this to make specific suggestions, but I think the goal should be to have an interface that can be used to define everything necessary to implement a model architecture. Ideally, to add support for a new architecture, it should only be necessary to define a new class and create a mapping between the architecture name in the GGUF file and this class. There may of course be more classes in the interface, but there should be a single entry point. So this should include more than just the graph build function; it should also include all the functions to load a model, create a context, and everything else that may be necessary to run a model. This interface would also need to be supported by other interfaces, such as the KV cache abstraction and the graph building helper functions that are currently in ….

To do this, I think it would be better to create an abstract interface that contains everything necessary to define a model architecture. I think that's likely to result in a cleaner and more maintainable codebase than using ….

This is of course a very high level suggestion, it will take a lot of work to define all the details.
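As a rough illustration of this suggestion - all names below are hypothetical, not from the PR - the single entry point could be an abstract class plus a registry keyed by the GGUF architecture name:

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

// hypothetical interface: everything needed to support one model architecture
struct llm_arch_i {
    virtual ~llm_arch_i() = default;

    virtual void load_hparams()   = 0; // read architecture-specific hyperparameters
    virtual void load_tensors()   = 0; // create/load the model tensors
    virtual void build_graph()    = 0; // build the compute graph for a batch
    virtual void create_context() = 0; // create a context suitable for this architecture
};

// registry mapping the architecture name from the GGUF file to a factory
using llm_arch_factory = std::function<std::unique_ptr<llm_arch_i>()>;

static std::map<std::string, llm_arch_factory> & llm_arch_registry() {
    static std::map<std::string, llm_arch_factory> reg;
    return reg;
}

// adding a new architecture would then only require defining a class and registering it:
//   llm_arch_registry()["my_arch"] = [] { return std::make_unique<llm_arch_my_arch>(); };
```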
Thanks for the suggestions. I'll aim to create the abstract model interface and restructure the implementation so that the …
This PR is getting close to completion. Here is an update of the new software architecture (see the Overview section below).

Pinging @MollySophia and @compilade if you could run some tests with this branch to check if the RWKV and Mamba models work correctly. Any suggestions for improving the code are welcome. Hoping to have this ready for review in the next few days.
I've been quite on/off recently, but hopefully I can have a deeper look into this during the weekend.
@ggerganov I see that there is an implicit assumption in …
Looks good overall. Some points I'm thinking about for my vision PR (a rough sketch follows below):

- Having a derived class `llama_vision_context : llama_context` as you said
- Input image tokens will be obtained via `llama_batch_ext`; they will be passed to `llama_vision_context::input_set`, which can work with pixel values instead of text tokens
- Output tensor will be saved to `llama_context::embd_tensor` ==> need to add this to the base class
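A rough sketch of the points above; `llama_vision_context`, `llama_batch_ext`, `input_set` and `embd_tensor` are the names from this comment, the rest is illustrative:

```cpp
struct ggml_tensor;
struct llama_batch_ext;                        // future batch type that can carry image data
struct llama_context { /* base class from this PR */ };

struct llama_vision_context : llama_context {
    // output of the vision encoder; per the comment, this member would be added to the base class
    ggml_tensor * embd_tensor = nullptr;

    // consumes pixel values from the batch instead of text tokens
    void input_set(const llama_batch_ext & batch);
};
```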
```cpp
        case LLM_ARCH_BERT:
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:
            ctx = new llama_context_enc(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
```
Should this be `LLAMA_GRAPH_TYPE_ENCODER`? (Though I know it's not currently in use.)
Yes, likely will change to `LLAMA_GRAPH_TYPE_ENCODER` to be more explicit, though the idea is to have a "default" graph for each model, which in this case makes no difference between using "default" and "encoder".
```cpp
    cparams.offload_kqv  = params.offload_kqv;
    cparams.flash_attn   = params.flash_attn;
    cparams.no_perf      = params.no_perf;
    cparams.pooling_type = params.pooling_type;
```
I'm wondering if we can further split these into "environment" params and "inference" params. AFAIK YaRN/RoPE is not used in recurrent models (and vision encoders usually use learned embeddings).

For example:

```cpp
// "environment" params, meaning they may affect performance, but do not change the result
cparams.n_seq_max       = std::max(1u, params.n_seq_max);
cparams.n_threads       = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.defrag_thold    = params.defrag_thold;
cparams.embeddings      = params.embeddings;
cparams.offload_kqv     = params.offload_kqv;
cparams.flash_attn      = params.flash_attn;
cparams.no_perf         = params.no_perf;

// "inference" params, may affect the result
cparams.yarn_ext_factor  = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast   = params.yarn_beta_fast;
cparams.yarn_beta_slow   = params.yarn_beta_slow;
cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

// everything else, not sure how to categorize them:
cparams.n_ctx   = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
```
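A minimal sketch of how such a split could look, using hypothetical struct names and the fields from the listing above:

```cpp
#include <cstdint>

// "environment" params: may affect performance, but must not change the results
struct llama_cparams_env {
    uint32_t n_seq_max;
    int32_t  n_threads;
    int32_t  n_threads_batch;
    float    defrag_thold;
    bool     embeddings;
    bool     offload_kqv;
    bool     flash_attn;
    bool     no_perf;
};

// "inference" params: may affect the computed results
struct llama_cparams_infer {
    float yarn_ext_factor;
    float yarn_attn_factor;
    float yarn_beta_fast;
    float yarn_beta_slow;
    float rope_freq_base;
    float rope_freq_scale;
};
```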
Also, since the model definition is now fully contained inside …

Currently, we can debug the cgraph using …

I'm not sure if the idea is worth exploring, but I can create a dedicated issue to discuss more if needed.
@fairydreaming Yes, you should add a new …
Overall yes. The details are not yet completely clear to me - I think once the T5 encoder-decoder use case is implemented we will have a clearer picture and a starting point for multi-modal support. What I am trying to do is to be able to compose the …

Extending this analogy, a vision model is likely to fit in the same …
@ggerganov I think there's still one thing missing. There should be an abstract kv cache interface, …
```cpp
    // if we have the output embeddings from the encoder, use them directly
    if (cross->t_embd) {
        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);

        return inp.cross_embd;
    }
```
With this change we now use the embeddings produced by the encoder (`cross->t_embd`) directly as input for the decoder's cross-attention, without downloading/uploading to/from RAM.

This seems to work correctly, but in debug it hits the following assert when I explicitly use `-dev none` on Mac:

```
./bin/llama-cli \
    -m ../models/google-t5-small/ggml-model-f16.gguf \
    -p 'Translate from English to German: The house is wonderful.' \
    -dev none
```

```
0.00.122.688 I llama_context_kv_self: constructing llama_context_kv_self
0.00.122.690 I init: kv_size = 4096, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 6, can_shift = 1
0.00.125.481 I init: CPU KV buffer size = 48.00 MiB
0.00.125.485 I llama_context_kv_self: KV self size = 48.00 MiB, K (f16): 24.00 MiB, V (f16): 24.00 MiB
0.00.138.987 I reserve: CPU compute buffer size = 30.00 MiB
0.00.138.988 I reserve: graph nodes = 197
0.00.138.988 I reserve: graph splits = 61 (with bs=512), 5 (with bs=1)
0.00.152.861 I reserve: CPU compute buffer size = 213.00 MiB
0.00.152.862 I reserve: graph nodes = 342
0.00.152.862 I reserve: graph splits = 98 (with bs=512), 18 (with bs=1)
0.00.152.869 I common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
0.00.152.869 W common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
Assertion failed: (tensor_alloc->offset == SIZE_MAX), function ggml_gallocr_init_tensor, file ggml-alloc.c, line 793.
Abort trap: 6
```

I think it is related to re-using the tensor from the encoder context, but I am not sure if the assert is correct in this case. @slaren Any ideas?

Edit: btw, it does not hit the assert without `-dev none`, or with `-dev none -fa`.
I am not sure exactly what triggers the assert - probably because the graph didn't change except in that one tensor that previously was a view and now isn't, and ggml-alloc is not correctly detecting that the graph changed in an incompatible way. However, I don't think this is correct either way, because to do this you would need to allocate this tensor in a different buffer/sched. It's not possible to use tensors allocated in the compute buffer in the next graph, since the compute buffer is reset with each graph.
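A minimal sketch of that idea (the helper name and parameters are hypothetical): copy the encoder output into a tensor allocated in its own backend buffer, so it stays valid after the scheduler resets the compute buffer for the next graph.

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// ctx_state must be a long-lived ggml context created with no_alloc = true;
// in real code the returned buffer would be stored and freed by the caller
static struct ggml_tensor * keep_cross_embd(
        struct ggml_context        * ctx_state,
        ggml_backend_buffer_type_t   buft,        // buffer type to allocate the copy in
        struct ggml_tensor         * t_embd_enc)  // encoder output (lives in the compute buffer)
{
    // persistent tensor with the same type/shape as the encoder output
    struct ggml_tensor * t_copy = ggml_dup_tensor(ctx_state, t_embd_enc);

    // allocate it in a dedicated backend buffer, not in the scheduler's compute buffer
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_state, buft);
    (void) buf; // the caller should keep this buffer alive as long as t_copy is used

    // copy the data out before the compute buffer is reused by the next graph
    ggml_backend_tensor_copy(t_embd_enc, t_copy);

    return t_copy;
}
```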
Overview
This PR is an intermediate step towards a more generic implementation that will support different underlying implementations of `llama_kv_cache`, `llama_context` and the graph building logic (a.k.a. `llm_build_context`). The `llama_kv_cache` is also introduced in the public API as an object, but its actual functionality is yet to be defined in follow-up PRs.

Currently, no functional changes have been introduced. Mainly the code has been reorganized in a way to allow implementing new abstractions. The main changes in the implementation are:

- Avoid direct use of `llama_kv_cache` in `llm_build_context`. The goal is to be able to construct the computation graphs only through the abstract `llama_context` interface, which will hide the actual KV cache implementation and thus allow it to be overloaded based on the parameters of the specific use case. More generally, the `llama_context` hides not only the KV cache implementation, but all the internal state (such as applied adapters, masks, etc., if any) with the exception of the model weights - these are still available to the `llm_build_context` in order to be able to construct the backbone graph of the various architectures.
- Avoid direct use of `llama_kv_cache` in `llama_decode`/`llama_encode`. These are abstracted through a new object `llama_batch_manager`, which is produced by the current `llama_context`. Again the goal is to not make explicit assumptions about the underlying KV cache implementation while processing the batches, and to be able to delegate this logic to the `llama_context`. The `llama_batch_manager` is produced by the `llama_context` and will handle logic such as restoring the KV cache state to a consistent state upon errors, splitting the input batch into micro batches according to the internal processing logic, etc.
- Move the KV cache related logic into `llama_kv_cache`. In the future, these functions will be overloaded for the specific KV cache implementations through a common abstract interface.

The modifications so far are quite substantial and touch too many lines. Even though the code is in a very intermediate state, with many members still publicly exposed and without proper object-oriented implementation in place, it should still be mergeable.
The general class hierarchy that I have in mind is like this:
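(A rough sketch based only on the class and interface names used in this description; the exact shape of the hierarchy is still open:)

```cpp
// rough sketch, using the names mentioned in this description

struct llama_build_i;         // graph-building interface used by llm_build_context
struct llama_batch_manager_i; // batch-processing interface used by decode/encode

struct llama_kv_cache { /* abstract KV cache interface */ };
struct llama_kv_cache_unified  : llama_kv_cache {}; // the current implementation
struct llama_kv_cache_standard : llama_kv_cache {}; // future: multi-user scenarios
struct llama_kv_cache_mamba    : llama_kv_cache {}; // future: Mamba-style architectures

struct llama_context { /* base: ggml buffers, backends, adapters - no notion of a KV cache */ };
struct llama_context_unified  : llama_context {};   // uses llama_kv_cache_unified
struct llama_context_standard : llama_context {};   // uses llama_kv_cache_standard
struct llama_context_no_kv    : llama_context {};   // e.g. encoder-only models
```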
Here, `llama_kv_cache_unified` is basically the `llama_kv_cache` implementation that we currently have. In the future, we will add more implementations that would be appropriate for multi-user scenarios (e.g. `llama_kv_cache_standard`) or for Mamba architectures (`llama_kv_cache_mamba`).

The base `llama_context` class will implement common functionality such as low-level `ggml` buffer and backend management + adapters, without the notion of a KV cache. The derived classes will specialize the `llama_context` for different use-cases.

The `llm_build_context` would operate only through the `llama_build_i` interface, and the batch processing will respectively only interact with the `llama_batch_manager_i` interface. The type of `llama_context` to construct in functions such as `llama_init_from_model()` would be determined based on the model and the specified context parameters. For example, the user would be able to create both a `llama_context_unified` and a `llama_context_standard` for a `LLM_ARCH_QWEN2` model. Or a `llama_context_no_kv` for an encoding-only `LLM_ARCH_BERT` model. And so on.

API changes
The current changes are only necessary to make the API more consistent in following the naming convention. To migrate, simply replace the old API calls with the new ones:

- Deprecated: the `llama_kv_cache_...` API
- New: the `llama_kv_self_...` API

In the future, the `llama_kv_cache_...` API will be changed to work with `struct llama_kv_cache` instead of `struct llama_context`, and the functionality will be extended to support things like saving, copying, loading, etc.
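For illustration, a header-style sketch of what that future API could look like, based on the `llama_get_kv_self` proposal earlier in this thread (none of this is part of this PR):

```cpp
// sketch: the llama_kv_cache_... functions would take the cache object directly
struct llama_kv_cache;

// obtain the cache instance from a context (per the earlier proposal)
LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx);

// operate on the cache object instead of the context
LLAMA_API void llama_kv_cache_clear (struct llama_kv_cache * kv);
LLAMA_API bool llama_kv_cache_seq_rm(struct llama_kv_cache * kv,
                                     llama_seq_id seq_id, llama_pos p0, llama_pos p1);

// plus clone/free/save/load, as mentioned in the proposal
```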
Notes

- `build_qwen2vl`, `inp_pos`, `lctx.n_pos_per_token` hack
- `n_outputs` and `n_outputs_enc` in `llm_build_context` seem incorrect
- `inp_s_seq` - not used
- `batch.pos == NULL` - `llama_context::pos_max()` is used incorrectly
- `llama_context` `encode()`/`decode()`
- `worst_case` from the `llama_graph_i` API?

PRs to resolve
New features